import os.path
import os
import pandas as pd
import numpy as np
from matplotlib import pyplot as plt
import matplotx

# Settings grid shared by the loaders below (the index each list drives is noted).
mysigmalist = [1, 3]            # GP continuous: "sigma" field of the file name (index i)
myNlist = [50, 100, 200, 500]   # sample sizes n (index j)
mylengthlist = [0.5, 1, 2]      # GP length scale l (index h)
mysigmakernel = [1, 2, 4]       # GP prior kernel scale sigma_0 (index g)

# Safety-constraint levels epsilon. The plots pair epsilon_all[:-2] with the
# last axis of the result arrays sliced by [..., :-2].
epsilon_all = np.array([1e-3, 2e-3, 3.3e-3, 5e-3, 6.7e-3, 8e-3, 1e-2, 2e-2, 3.3e-2, 5e-2, 6.7e-2, 8e-2,
                        1e-1, 0.125, 0.15, 0.175, 2e-1, 3.3e-1, 6.7e-1, 8e-1, 1])

# XBCF continuous results: entry [s][j] is loaded from
# Data/Continuous/XBCF_<n>_<s+1>_.npy with n = myNlist[j]. All files are
# required; np.load raises FileNotFoundError on the first missing one.
# (The trailing '.npy' element of the join produces the final '_' in the name.)
XBCF_continuous_result = [
    [np.load('Data/Continuous/' + "_".join(["XBCF", str(N), str(setting), '.npy']))
     for N in myNlist]
    for setting in (1, 2, 3)
]

# Per-cell summaries over axis 0 of each loaded array. Quantile index [-1]
# (level 0.9) is what the plots label the "90% quantile".
XBCF_continuous_result_mean = [
    [None if res is None else res.mean(axis=0) for res in row]
    for row in XBCF_continuous_result
]
XBCF_continuous_result_quantile = [
    [None if res is None else np.quantile(res, q=[0.1, 0.25, 0.5, 0.75, 0.9], axis=0)
     for res in row]
    for row in XBCF_continuous_result
]

# GP continuous results indexed [g][h][i][j] by
# (mysigmakernel[g], mylengthlist[h], mysigmalist[i], myNlist[j]), each loaded
# from Data/Continuous/GP_<n>_<sigma>_<l>_<sigma_0>_.npy. A missing file is
# skipped, so its entry stays None in all three grids.
GP_continuous_result = [[[[None] * len(myNlist) for _ in mysigmalist] for _ in mylengthlist] for _ in mysigmakernel]
GP_continuous_result_mean = [[[[None] * len(myNlist) for _ in mysigmalist] for _ in mylengthlist] for _ in mysigmakernel]
GP_continuous_result_quantile = [[[[None] * len(myNlist) for _ in mysigmalist] for _ in mylengthlist] for _ in mysigmakernel]

for g, sigmakernel in enumerate(mysigmakernel):
    for h, length_scale in enumerate(mylengthlist):
        for i, sigma in enumerate(mysigmalist):
            for j, N in enumerate(myNlist):
                mydataname = "_".join(["GP", str(N), str(sigma), str(length_scale), str(sigmakernel), '.npy'])
                datapath = 'Data/Continuous/' + mydataname
                if os.path.exists(datapath):
                    loaded = np.load(datapath)
                    GP_continuous_result[g][h][i][j] = loaded
                    GP_continuous_result_mean[g][h][i][j] = loaded.mean(axis=0)
                    # Same levels as the XBCF and binary summaries. Every plot
                    # reads index [-1] and labels it the "90% quantile"; the
                    # former 0.05/0.95 levels made that label wrong.
                    GP_continuous_result_quantile[g][h][i][j] = np.quantile(
                        loaded, q=[0.1, 0.25, 0.5, 0.75, 0.9], axis=0)

# Plot the quantiles of the ACRisk against the safety-constraint epsilon.
def plotquantile_XBCF_num1():
    """Plot the upper XBCF ACRisk quantile against epsilon for each sample size.

    Uses XBCF setting index 1 with n = 50, 100, 200, 500. Data comes from the
    module-level ``XBCF_continuous_result_quantile`` and ``epsilon_all``. The
    figure is shown and saved to ``Figure_additional/XBCF_Continuous_num1.pdf``.
    """
    fig1, ax1 = plt.subplots(figsize=(12, 8))
    ax1.set_ylabel(r'$90\%$ quantile of the ACRisk', fontsize=25)
    ax1.yaxis.set_tick_params(labelsize=20)
    ax1.xaxis.set_tick_params(labelsize=20)
    ax1.set_xlim(-0.0, 0.34)
    ax1.set_ylim(0, 0.39)
    # NOTE(review): the title calls setting index 1 "low signal-to-noise", while
    # plotquantile_continuous_xbcf_signal1 labels the same index "Medium"; confirm.
    for i, j, temp in [(1, 0, 0), (1, 1, 1), (1, 2, 2), (1, 3, 3)]:
        quantiles = XBCF_continuous_result_quantile[i][j]
        # Test the stored entry itself. Slicing a missing (None) entry first would
        # raise TypeError, and an array slice is never None.
        if quantiles is not None:
            NName = ['50', '100', '200', '500'][j]
            mylinetype = ['dotted', 'solid', 'dashed', 'dashdot'][temp]
            mycolor = ['black', 'blue', 'green', 'red'][temp]
            # [-1]: highest quantile level; [0]: first row of the metric axis
            # (presumably the ACRisk, per the y label); last two epsilons dropped.
            ax1.plot(epsilon_all[:-2], quantiles[-1, 0, :-2], c=mycolor, label='n=' + NName,
                     linestyle=mylinetype, linewidth=3)
    ax1.set_xlabel(r'$\epsilon$ in the safety constraint', fontsize=28)
    ax1.set_title("ACRisk with different sample size (low signal-to-noise ratio)", fontsize=28)
    fig1.legend(loc='upper left', bbox_to_anchor=(0.12, 0.89), fontsize=22)
    fig1.tight_layout()
    fig1.show()
    # savefig does not create missing directories.
    os.makedirs('Figure_additional', exist_ok=True)
    fig1.savefig('Figure_additional/' + 'XBCF_Continuous_num1' + '.pdf')
def plotquantile_continuous_xbcf_signal1():
    """Plot the upper XBCF ACRisk quantile against epsilon for three signal settings.

    Uses setting indices 0, 1, 2 (labelled High/Medium/Low) at n index 1
    (n = 100). Data comes from the module-level ``XBCF_continuous_result_quantile``
    and ``epsilon_all``. The figure is shown and saved to
    ``Figure_additional/XBCF_Continuous_signal1.pdf``.
    """
    fig1, ax1 = plt.subplots(figsize=(12, 8))
    ax1.set_ylabel(r'$90\%$ quantile of the ACRisk', fontsize=25)
    ax1.yaxis.set_tick_params(labelsize=20)
    ax1.xaxis.set_tick_params(labelsize=20)
    ax1.set_xlim(-0.0, 0.4)
    ax1.set_ylim(0, 0.39)
    for i, j, temp in [(0, 1, 0), (1, 1, 1), (2, 1, 2)]:
        quantiles = XBCF_continuous_result_quantile[i][j]
        # Test the stored entry itself; slicing a None entry first would raise.
        if quantiles is not None:
            SigmaName = ['High', 'Medium', 'Low'][i]
            mylinetype = ['dotted', 'solid', 'dashed', 'dashdot'][temp]
            mycolor = ['black', 'blue', 'green', 'red'][temp]
            # Highest quantile level, first metric row, last two epsilons dropped.
            ax1.plot(epsilon_all[:-2], quantiles[-1, 0, :-2], c=mycolor, label=SigmaName,
                     linestyle=mylinetype, linewidth=3)
    ax1.set_xlabel(r'$\epsilon$ in the safety constraint', fontsize=28)
    ax1.set_title("ACRisk with different signal strength (n=100)", fontsize=28)
    fig1.legend(loc='upper left', bbox_to_anchor=(0.12, 0.89), fontsize=22)
    fig1.tight_layout()
    fig1.show()
    # savefig does not create missing directories.
    os.makedirs('Figure_additional', exist_ok=True)
    fig1.savefig('Figure_additional/' + 'XBCF_Continuous_signal1' + '.pdf')
def plotquantile_continuous_gp_l1():
    """Plot the upper GP ACRisk quantile against epsilon for length scales l = 0.5, 1, 2.

    Fixed: kernel-scale index g = 1 (sigma_0 = 2, per the title), sigma index
    i = 1 and n index j = 2. Data comes from the module-level
    ``GP_continuous_result_quantile`` and ``epsilon_all``. The figure is shown
    and saved to ``Figure_additional/GP_Continuous_l1.pdf``.
    """
    fig1, ax1 = plt.subplots(figsize=(12, 8))
    ax1.set_ylabel(r'$90\%$ quantile of the ACRisk', fontsize=25)
    ax1.yaxis.set_tick_params(labelsize=20)
    ax1.xaxis.set_tick_params(labelsize=20)
    ax1.set_xlim(-0.0, 0.4)
    ax1.set_ylim(0, 0.39)
    for g, h, i, j, temp in [(1, 0, 1, 2, 0), (1, 1, 1, 2, 1), (1, 2, 1, 2, 2)]:
        # Guard on the array that is actually plotted (missing files leave None).
        quantiles = GP_continuous_result_quantile[g][h][i][j]
        if quantiles is not None:
            LengthName = ['l=0.5', 'l=1', 'l=2'][h]
            mycolor = ['black', 'blue', 'green', 'red'][temp]
            mylinetype = ['solid', 'dotted', 'dashed', 'dashdot'][temp]
            # Highest quantile level, first metric row, last two epsilons dropped.
            ax1.plot(epsilon_all[:-2], quantiles[-1, 0, :-2], c=mycolor, linestyle=mylinetype,
                     linewidth=3, label=LengthName)
    ax1.set_xlabel(r'$\epsilon$ in the safety constraint', fontsize=28)
    ax1.set_title(r"Tail distribution of the ACRisk ($\sigma_0=2)$", fontsize=28)
    fig1.legend(loc='upper left', bbox_to_anchor=(0.12, 0.89), fontsize=22)
    fig1.tight_layout()
    fig1.show()
    # savefig does not create missing directories.
    os.makedirs('Figure_additional', exist_ok=True)
    fig1.savefig('Figure_additional/' + 'GP_Continuous_l1' + '.pdf')
def plotquantile_continuous_gp_sigma01():
    """Plot the upper GP ACRisk quantile against epsilon for sigma_0 = 1, 2, 4.

    Fixed: length-scale index h = 2 (l = 2, per the title), sigma index i = 1
    and n index j = 2. Data comes from the module-level
    ``GP_continuous_result_quantile`` and ``epsilon_all``. The figure is shown
    and saved to ``Figure_additional/GP_Continuous_sigma01.pdf``.
    """
    fig1, ax1 = plt.subplots(figsize=(12, 8))
    ax1.set_ylabel(r'$90\%$ quantile of the ACRisk', fontsize=25)
    ax1.yaxis.set_tick_params(labelsize=20)
    ax1.xaxis.set_tick_params(labelsize=20)
    ax1.set_xlim(-0.0, 0.4)
    ax1.set_ylim(0, 0.39)
    for g, h, i, j, temp in [(0, 2, 1, 2, 0), (1, 2, 1, 2, 1), (2, 2, 1, 2, 2)]:
        # Guard on the array that is actually plotted (missing files leave None).
        quantiles = GP_continuous_result_quantile[g][h][i][j]
        if quantiles is not None:
            PriorName = [r'$\sigma_0=1$', r'$\sigma_0=2$', r'$\sigma_0=4$'][g]
            mycolor = ['black', 'blue', 'green', 'red'][temp]
            mylinetype = ['solid', 'dotted', 'dashed', 'dashdot'][temp]
            # Highest quantile level, first metric row, last two epsilons dropped.
            ax1.plot(epsilon_all[:-2], quantiles[-1, 0, :-2], c=mycolor, linestyle=mylinetype,
                     linewidth=3, label=PriorName)
    ax1.set_xlabel(r'$\epsilon$ in the safety constraint', fontsize=28)
    ax1.set_title("Tail distribution of the ACRisk (l=2)", fontsize=28)
    fig1.legend(loc='upper left', bbox_to_anchor=(0.12, 0.89), fontsize=22)
    fig1.tight_layout()
    fig1.show()
    # savefig does not create missing directories.
    os.makedirs('Figure_additional', exist_ok=True)
    fig1.savefig('Figure_additional/' + 'GP_Continuous_sigma01' + '.pdf')

# Draw and save every continuous-outcome figure, in this order.
for _plot in (plotquantile_XBCF_num1,
              plotquantile_continuous_xbcf_signal1,
              plotquantile_continuous_gp_l1,
              plotquantile_continuous_gp_sigma01):
    _plot()


# Binary-outcome GP experiments, "Random" data sets. Nested-list axes,
# outermost first: g -> sigma_0, h -> length scale l, i -> strength, j -> n.
mysigmakernel = [1, 2, 4]   # g
mylengthlist = [0.5, 1, 2]  # h
mystrengthlist = [3, 6]     # i
myNlist = [50, 100, 200]    # j

# Placeholder grids; an entry stays None when its .npy file is absent.
GP_binary_result = [[[[None] * 3, [None] * 3] for _ in range(3)] for _ in range(3)]
GP_binary_result_mean = [[[[None] * 3, [None] * 3] for _ in range(3)] for _ in range(3)]
GP_binary_result_quantile = [[[[None] * 3, [None] * 3] for _ in range(3)] for _ in range(3)]

for g, sigmakernel in enumerate(mysigmakernel):
    for h, length_scale in enumerate(mylengthlist):
        for i, strength in enumerate(mystrengthlist):
            for j, N in enumerate(myNlist):
                # e.g. "GP_100_0.5_4_3_.npy": the trailing '.npy' element adds the last '_'.
                mydataname = "_".join(["GP", str(N), str(length_scale), str(sigmakernel), str(strength), '.npy'])
                if not os.path.exists('Data/Binary/Random/' + mydataname):
                    continue
                loaded = np.load('Data/Binary/Random/' + mydataname)
                GP_binary_result[g][h][i][j] = loaded
                GP_binary_result_mean[g][h][i][j] = loaded.mean(axis=0)
                GP_binary_result_quantile[g][h][i][j] = np.quantile(loaded, q=[0.1, 0.25, 0.5, 0.75, 0.9], axis=0)

def plot_binary_gp_strength1():
    """Plot the average ACRisk of the binary GP policy against epsilon.

    Compares n index j = 1, 2 at both signal strengths (i = 0, 1), with g = 2
    and h = 0 fixed. Plots row 0 of the module-level ``GP_binary_result_mean``.
    The figure is shown and saved to ``Figure_additional/GP_binary_numsignal1.pdf``.
    """
    fig1, ax1 = plt.subplots(figsize=(12, 8))
    ax1.set_ylabel('Average ACRisk of the policy', fontsize=25)
    ax1.yaxis.set_tick_params(labelsize=20)
    ax1.xaxis.set_tick_params(labelsize=20)
    ax1.set_xlim(-0.0, 0.4)
    ax1.set_ylim(0, 0.149)
    for g, h, i, j, temp in [(2, 0, 0, 1, 0), (2, 0, 1, 1, 1), (2, 0, 0, 2, 2), (2, 0, 1, 2, 3)]:
        means = GP_binary_result_mean[g][h][i][j]
        if means is not None:
            SigmaName = ['Low signal', 'High signal'][i]
            NName = ['50', '100', '200', '500'][j]
            mycolor = ['black', 'blue', 'green', 'red'][temp]
            # Solid lines for the smaller n, dashed for the larger.
            mylinetype = ['solid', 'solid', 'dashed', 'dashed'][temp]
            # Row 0 is the ACRisk (per the y label); last two epsilons dropped.
            ax1.plot(epsilon_all[:-2], means[0, :-2], c=mycolor, linestyle=mylinetype, linewidth=3,
                     label='n=' + NName + ', ' + SigmaName)
    ax1.set_xlabel(r'$\epsilon$ in the safety constraint', fontsize=28)
    ax1.set_title("ACRisk with different sample size and signal strength", fontsize=28)
    fig1.legend(loc='upper left', bbox_to_anchor=(0.6, 0.39), fontsize=22)
    fig1.tight_layout()
    fig1.show()
    # savefig does not create missing directories.
    os.makedirs('Figure_additional', exist_ok=True)
    fig1.savefig('Figure_additional/' + 'GP_binary_numsignal1' + '.pdf')
def plot_binary_gp_strength2():
    """Plot the average value of the binary GP policy against epsilon.

    Same settings as the ACRisk version (g = 2, h = 0; i = 0, 1; j = 1, 2), but
    plots row 1 of the module-level ``GP_binary_result_mean``. The figure is
    shown and saved to ``Figure_additional/GP_binary_numsignal2.pdf``.
    """
    fig1, ax1 = plt.subplots(figsize=(12, 8))
    ax1.set_ylabel('Average Value of the policy', fontsize=25)
    ax1.yaxis.set_tick_params(labelsize=20)
    ax1.xaxis.set_tick_params(labelsize=20)
    ax1.set_xlim(-0.0, 0.4)
    ax1.set_ylim(0.5, 0.529)
    for g, h, i, j, temp in [(2, 0, 0, 1, 0), (2, 0, 1, 1, 1), (2, 0, 0, 2, 2), (2, 0, 1, 2, 3)]:
        means = GP_binary_result_mean[g][h][i][j]
        if means is not None:
            SigmaName = ['Low signal', 'High signal'][i]
            NName = ['50', '100', '200', '500'][j]
            mycolor = ['black', 'blue', 'green', 'red'][temp]
            # Solid lines for the smaller n, dashed for the larger.
            mylinetype = ['solid', 'solid', 'dashed', 'dashed'][temp]
            # Row 1 is the policy value (per the y label); last two epsilons dropped.
            ax1.plot(epsilon_all[:-2], means[1, :-2], c=mycolor, linestyle=mylinetype, linewidth=3,
                     label='n=' + NName + ', ' + SigmaName)
    ax1.set_xlabel(r'$\epsilon$ in the safety constraint', fontsize=28)
    ax1.set_title("Value with different sample size and signal strength", fontsize=28)
    fig1.legend(loc='upper left', bbox_to_anchor=(0.6, 0.92), fontsize=22)
    fig1.tight_layout()
    fig1.show()
    # savefig does not create missing directories.
    os.makedirs('Figure_additional', exist_ok=True)
    fig1.savefig('Figure_additional/' + 'GP_binary_numsignal2' + '.pdf')

# Draw and save the sample-size/strength figures from the "Random" data sets.
for _plot in (plot_binary_gp_strength1, plot_binary_gp_strength2):
    _plot()


# Reload the same GP grid from the "Deterministic" data sets. This replaces the
# "Random" results in the GP_binary_result* globals, which the strength
# figures above have already drawn. Entries stay None for absent files.
GP_binary_result = [[[[None] * 3, [None] * 3] for _ in range(3)] for _ in range(3)]
GP_binary_result_mean = [[[[None] * 3, [None] * 3] for _ in range(3)] for _ in range(3)]
GP_binary_result_quantile = [[[[None] * 3, [None] * 3] for _ in range(3)] for _ in range(3)]

for g, sigmakernel in enumerate(mysigmakernel):
    for h, length_scale in enumerate(mylengthlist):
        for i, strength in enumerate(mystrengthlist):
            for j, N in enumerate(myNlist):
                mydataname = "_".join(["GP", str(N), str(length_scale), str(sigmakernel), str(strength), '.npy'])
                if not os.path.exists('Data/Binary/Deterministic/' + mydataname):
                    continue
                loaded = np.load('Data/Binary/Deterministic/' + mydataname)
                GP_binary_result[g][h][i][j] = loaded
                GP_binary_result_mean[g][h][i][j] = loaded.mean(axis=0)
                GP_binary_result_quantile[g][h][i][j] = np.quantile(loaded, q=[0.1, 0.25, 0.5, 0.75, 0.9], axis=0)

def plot_binary_gp_l1():
    """Plot the average ACRisk of the binary GP policy for length scales l = 0.5, 1, 2.

    Fixed: g = 1 (sigma_0 = 2, per the title), signal index i = 0, n index j = 2.
    Plots row 0 of the module-level ``GP_binary_result_mean``. Lines are labelled
    to the right with ``matplotx.line_labels`` instead of a legend. The figure is
    shown and saved to ``Figure_additional/GP_binary_l1.pdf``.
    """
    fig1, ax1 = plt.subplots(figsize=(12, 8))
    ax1.set_ylabel('Average ACRisk of the policy', fontsize=25)
    ax1.yaxis.set_tick_params(labelsize=20)
    ax1.xaxis.set_tick_params(labelsize=20)
    ax1.set_xlim(-0.0, 0.4)
    ax1.set_ylim(0, 0.209)
    # NOTE(review): this ACRisk figure uses signal index i=0, while the matching
    # value figure (plot_binary_gp_l2) uses i=1; confirm this is intended.
    for g, h, i, j, temp in [(1, 0, 0, 2, 0), (1, 1, 0, 2, 1), (1, 2, 0, 2, 2)]:
        means = GP_binary_result_mean[g][h][i][j]
        if means is not None:
            LengthName = ['l=0.5', 'l=1', 'l=2'][h]
            mycolor = ['black', 'blue', 'green', 'red'][temp]
            mylinetype = ['solid', 'dotted', 'dashed', 'dashdot'][temp]
            # Row 0 is the ACRisk (per the y label); last two epsilons dropped.
            ax1.plot(epsilon_all[:-2], means[0, :-2], c=mycolor, linestyle=mylinetype, linewidth=3,
                     label=LengthName)
    matplotx.line_labels(fontsize=20)  # line labels to the right
    ax1.set_xlabel(r'$\epsilon$ in the safety constraint', fontsize=28)
    ax1.set_title(r"ACRisk with different prior smoothness l ($\sigma_0=2)$", fontsize=28)
    fig1.tight_layout()
    fig1.show()
    # savefig does not create missing directories.
    os.makedirs('Figure_additional', exist_ok=True)
    fig1.savefig('Figure_additional/' + 'GP_binary_l1' + '.pdf')
def plot_binary_gp_l2():
    """Plot the average value of the binary GP policy for length scales l = 0.5, 1, 2.

    Fixed: g = 1 (sigma_0 = 2, per the title), signal index i = 1, n index j = 2.
    Plots row 1 of the module-level ``GP_binary_result_mean``. Lines are labelled
    to the right with ``matplotx.line_labels``. The figure is shown and saved to
    ``Figure_additional/GP_binary_l2.pdf``.
    """
    fig1, ax1 = plt.subplots(figsize=(12, 8))
    ax1.set_ylabel('Average Value of the policy', fontsize=25)
    ax1.yaxis.set_tick_params(labelsize=20)
    ax1.xaxis.set_tick_params(labelsize=20)
    ax1.set_xlim(-0.0, 0.4)
    ax1.set_ylim(0.488, 0.53)
    for g, h, i, j, temp in [(1, 0, 1, 2, 0), (1, 1, 1, 2, 1), (1, 2, 1, 2, 2)]:
        means = GP_binary_result_mean[g][h][i][j]
        if means is not None:
            LengthName = ['l=0.5', 'l=1', 'l=2'][h]
            mycolor = ['black', 'blue', 'green', 'red'][temp]
            mylinetype = ['solid', 'dotted', 'dashed', 'dashdot'][temp]
            # Row 1 is the policy value (per the y label); last two epsilons dropped.
            ax1.plot(epsilon_all[:-2], means[1, :-2], c=mycolor, linestyle=mylinetype, linewidth=3,
                     label=LengthName)
    matplotx.line_labels(fontsize=20)  # line labels to the right
    ax1.set_xlabel(r'$\epsilon$ in the safety constraint', fontsize=28)
    ax1.set_title(r"Value with different prior smoothness l ($\sigma_0=2)$", fontsize=28)
    fig1.tight_layout()
    fig1.show()
    # savefig does not create missing directories.
    os.makedirs('Figure_additional', exist_ok=True)
    fig1.savefig('Figure_additional/' + 'GP_binary_l2' + '.pdf')
def plot_binary_gp_sigma01():
    """Plot the average ACRisk of the binary GP policy for sigma_0 = 1, 2, 4.

    Fixed: h = 2 (l = 2, per the title), signal index i = 1, n index j = 2.
    Plots row 0 of the module-level ``GP_binary_result_mean``. Lines are labelled
    to the right with ``matplotx.line_labels``. The figure is shown and saved to
    ``Figure_additional/GP_binary_sigma01.pdf``.
    """
    fig1, ax1 = plt.subplots(figsize=(12, 8))
    ax1.set_ylabel('Average ACRisk of the policy', fontsize=25)
    ax1.yaxis.set_tick_params(labelsize=20)
    ax1.xaxis.set_tick_params(labelsize=20)
    ax1.set_xlim(-0.0, 0.4)
    ax1.set_ylim(0, 0.239)
    for g, h, i, j, temp in [(0, 2, 1, 2, 0), (1, 2, 1, 2, 1), (2, 2, 1, 2, 2)]:
        means = GP_binary_result_mean[g][h][i][j]
        if means is not None:
            PriorName = [r'$\sigma_0=1$', r'$\sigma_0=2$', r'$\sigma_0=4$'][g]
            mycolor = ['black', 'blue', 'green', 'red'][temp]
            mylinetype = ['solid', 'dotted', 'dashed', 'dashdot'][temp]
            # Row 0 is the ACRisk (per the y label); last two epsilons dropped.
            ax1.plot(epsilon_all[:-2], means[0, :-2], c=mycolor, linestyle=mylinetype, linewidth=3,
                     label=PriorName)
    matplotx.line_labels(fontsize=20)  # line labels to the right
    ax1.set_xlabel(r'$\epsilon$ in the safety constraint', fontsize=28)
    ax1.set_title(r"ACRisk with different prior strength $\sigma_0$ (l=2)", fontsize=28)
    fig1.tight_layout()
    fig1.show()
    # savefig does not create missing directories.
    os.makedirs('Figure_additional', exist_ok=True)
    fig1.savefig('Figure_additional/' + 'GP_binary_sigma01' + '.pdf')
def plot_binary_gp_sigma02():
    """Plot the average value of the binary GP policy for sigma_0 = 1, 2, 4.

    Fixed: h = 2 (l = 2, per the title), signal index i = 1, n index j = 2.
    Plots row 1 of the module-level ``GP_binary_result_mean``. Lines are labelled
    to the right with ``matplotx.line_labels``. The figure is shown and saved to
    ``Figure_additional/GP_binary_sigma02.pdf``.
    """
    fig1, ax1 = plt.subplots(figsize=(12, 8))
    ax1.set_ylabel('Average Value of the policy', fontsize=25)
    ax1.yaxis.set_tick_params(labelsize=20)
    ax1.xaxis.set_tick_params(labelsize=20)
    ax1.set_xlim(-0.0, 0.4)
    ax1.set_ylim(0.488, 0.53)
    for g, h, i, j, temp in [(0, 2, 1, 2, 0), (1, 2, 1, 2, 1), (2, 2, 1, 2, 2)]:
        means = GP_binary_result_mean[g][h][i][j]
        if means is not None:
            PriorName = [r'$\sigma_0=1$', r'$\sigma_0=2$', r'$\sigma_0=4$'][g]
            mycolor = ['black', 'blue', 'green', 'red'][temp]
            mylinetype = ['solid', 'dotted', 'dashed', 'dashdot'][temp]
            # Row 1 is the policy value (per the y label); last two epsilons dropped.
            ax1.plot(epsilon_all[:-2], means[1, :-2], c=mycolor, linestyle=mylinetype,
                     linewidth=3, label=PriorName)
    matplotx.line_labels(fontsize=20)  # line labels to the right
    ax1.set_xlabel(r'$\epsilon$ in the safety constraint', fontsize=28)
    ax1.set_title(r"Value with different prior strength $\sigma_0$ ($l=2)$", fontsize=28)
    fig1.tight_layout()
    fig1.show()
    # savefig does not create missing directories.
    os.makedirs('Figure_additional', exist_ok=True)
    fig1.savefig('Figure_additional/' + 'GP_binary_sigma02' + '.pdf')

# Draw and save the prior-sensitivity figures from the "Deterministic" data sets.
for _plot in (plot_binary_gp_l1,
              plot_binary_gp_l2,
              plot_binary_gp_sigma01,
              plot_binary_gp_sigma02):
    _plot()

